package com.yeskery.nut.extend.jdbc;

import com.yeskery.nut.core.Holder;
import com.yeskery.nut.transaction.TransactionException;
import com.yeskery.nut.transaction.TransactionManager;
import com.yeskery.nut.util.JdbcUtils;

import javax.sql.DataSource;
import java.lang.reflect.Method;
import java.sql.*;
import java.util.*;
import java.util.function.BiConsumer;
import java.util.stream.Stream;

/**
 * Default {@link JdbcTemplate} implementation.
 *
 * <p>Connections are taken from the current transaction when the {@link TransactionManager}
 * exposes one, otherwise from the configured {@link DataSource}. Statements, result sets and
 * connections opened by a call are released before the call returns, except for the
 * {@code queryForStream}/{@code queryForObjectStream} methods, whose resources are released
 * when the returned {@link Stream} is closed.
 *
 * <p>For every overload taking an {@code Object[] params} argument, a {@code null} array is
 * treated as "no parameters".
 *
 * @author sprout
 * 2022-06-02 14:13
 */
public class JdbcTemplateImpl implements JdbcTemplate {

    /** Message used when a JDBC call fails with an {@link SQLException}. */
    private static final String EXECUTE_FAIL_MESSAGE = "JDBC SQL Execute Fail.";

    /** Message used when a single-row query returns no row. */
    private static final String DATA_NOT_FOUND_MESSAGE = "Data Not Found.";

    /** Message used when a single-row query returns more than one row. */
    private static final String MULTI_ROWS_FOUND_MESSAGE = "Multi Value Found.";

    /** Data source used when no transactional connection is available. */
    private final DataSource dataSource;

    /** Transaction manager that may supply the connection of the current transaction. */
    private final TransactionManager transactionManager;

    /** Converter that copies result set rows into maps and objects. */
    private final JdbcResultConverter resultConverter = JdbcResultConverter.getInstance();

    /**
     * Creates a JDBC template.
     *
     * @param dataSource data source, must not be {@code null}
     * @param transactionManager transaction manager, must not be {@code null}
     * @throws DataAccessException if either argument is {@code null}
     */
    public JdbcTemplateImpl(DataSource dataSource, TransactionManager transactionManager) {
        if (dataSource == null) {
            throw new DataAccessException("DataSource Must Not Be Null.");
        }
        this.dataSource = dataSource;
        if (transactionManager == null) {
            throw new DataAccessException("TransactionManager Must Not Be Null.");
        }
        this.transactionManager = transactionManager;
    }

    @Override
    public DataSource getDataSource() {
        return dataSource;
    }

    /**
     * Runs the callback with a connection and releases the connection afterwards.
     *
     * @param callback work to perform on the connection
     * @param <T> result type
     * @return the callback result
     * @throws DataAccessException if an {@link SQLException} occurs
     */
    @Override
    public <T> T execute(ConnectionCallback<T> callback) {
        Connection connection = null;
        try {
            connection = getConnection();
            return callback.callback(connection);
        } catch (SQLException e) {
            throw new DataAccessException(EXECUTE_FAIL_MESSAGE, e);
        } finally {
            freeConnection(connection);
        }
    }

    /**
     * Runs the callback with a freshly created {@link Statement} and releases it afterwards.
     *
     * @param callback work to perform on the statement
     * @param <T> result type
     * @return the callback result
     * @throws DataAccessException if an {@link SQLException} occurs
     */
    @Override
    public <T> T execute(StatementCallback<T> callback) {
        Connection connection = null;
        Statement statement = null;
        try {
            connection = getConnection();
            statement = connection.createStatement();
            return callback.callback(statement);
        } catch (SQLException e) {
            throw new DataAccessException(EXECUTE_FAIL_MESSAGE, e);
        } finally {
            freeConnection(connection, statement);
        }
    }

    /**
     * Prepares the given SQL, runs the callback with the {@link PreparedStatement} and
     * releases it afterwards.
     *
     * @param sql SQL to prepare
     * @param callback work to perform on the prepared statement
     * @param <T> result type
     * @return the callback result
     * @throws DataAccessException if an {@link SQLException} occurs
     */
    @Override
    public <T> T execute(String sql, PreparedStatementCallback<T> callback) {
        Connection connection = null;
        PreparedStatement pstmt = null;
        try {
            connection = getConnection();
            pstmt = connection.prepareStatement(sql);
            return callback.callback(pstmt);
        } catch (SQLException e) {
            throw new DataAccessException(EXECUTE_FAIL_MESSAGE, e);
        } finally {
            freeConnection(connection, pstmt);
        }
    }

    /**
     * Queries exactly one row and maps it with the callback.
     *
     * @param sql SQL to execute
     * @param callback row mapper
     * @param <T> result type
     * @return the mapped row, or {@code null} when the query returns no row
     * @throws MultiRowsFoundException if the query returns more than one row
     */
    @Override
    public <T> T query(String sql, ResultSetRowCallback<T> callback) {
        Holder<T> holder = new Holder<>();
        try {
            query(sql, getQueryOneResultSetHandler(holder, callback));
        } catch (DataNotFoundException ignored) {
            // An empty result is reported to the caller as null.
            return null;
        }
        return holder.getObject();
    }

    /**
     * Queries exactly one row with bound parameters and maps it with the callback.
     *
     * @param sql SQL to execute
     * @param params positional parameters, {@code null} means none
     * @param callback row mapper
     * @param <T> result type
     * @return the mapped row, or {@code null} when the query returns no row
     * @throws MultiRowsFoundException if the query returns more than one row
     */
    @Override
    public <T> T query(String sql, Object[] params, ResultSetRowCallback<T> callback) {
        Holder<T> holder = new Holder<>();
        try {
            query(sql, params, getQueryOneResultSetHandler(holder, callback));
        } catch (DataNotFoundException ignored) {
            // An empty result is reported to the caller as null.
            return null;
        }
        return holder.getObject();
    }

    /**
     * Executes the SQL with a plain {@link Statement} and hands the result set to the handler.
     *
     * @param sql SQL to execute
     * @param handler consumer of the result set
     * @throws DataAccessException if an {@link SQLException} occurs
     */
    @Override
    public void query(String sql, ResultSetHandler handler) {
        Connection connection = null;
        Statement statement = null;
        ResultSet resultSet = null;
        try {
            connection = getConnection();
            statement = connection.createStatement();
            resultSet = statement.executeQuery(sql);
            handler.handle(resultSet);
        } catch (SQLException e) {
            throw new DataAccessException(EXECUTE_FAIL_MESSAGE, e);
        } finally {
            freeConnection(connection, statement, resultSet);
        }
    }

    /**
     * Executes the SQL with a {@link PreparedStatement} and hands the result set to the handler.
     *
     * @param sql SQL to execute
     * @param params positional parameters, {@code null} means none
     * @param handler consumer of the result set
     * @throws DataAccessException if an {@link SQLException} occurs
     */
    @Override
    public void query(String sql, Object[] params, ResultSetHandler handler) {
        Connection connection = null;
        PreparedStatement pstmt = null;
        ResultSet resultSet = null;
        try {
            connection = getConnection();
            pstmt = connection.prepareStatement(sql);
            bindParameters(pstmt, params);
            resultSet = pstmt.executeQuery();
            handler.handle(resultSet);
        } catch (SQLException e) {
            throw new DataAccessException(EXECUTE_FAIL_MESSAGE, e);
        } finally {
            freeConnection(connection, pstmt, resultSet);
        }
    }

    /**
     * Queries exactly one row and returns its first column converted to {@code clazz}.
     *
     * @return the value, or {@code null} when the query returns no row
     */
    @Override
    public <T> T query(String sql, Class<T> clazz) {
        return query(sql, firstColumnCallback(clazz));
    }

    /**
     * Queries exactly one row with bound parameters and returns its first column
     * converted to {@code clazz}.
     *
     * @return the value, or {@code null} when the query returns no row
     */
    @Override
    public <T> T query(String sql, Object[] params, Class<T> clazz) {
        return query(sql, params, firstColumnCallback(clazz));
    }

    /**
     * Queries exactly one row and copies it into a new instance of {@code clazz}.
     *
     * @return the populated object, or {@code null} when the query returns no row
     * @throws MultiRowsFoundException if the query returns more than one row
     */
    @Override
    public <T> T queryForObject(String sql, ColumnNameStrategy columnNameStrategy, Class<T> clazz) {
        try {
            return doQueryObject(columnNameStrategy, clazz,
                    (setMethodMap, target) -> query(sql, singleRowObjectHandler(setMethodMap, target)));
        } catch (DataNotFoundException ignored) {
            // An empty result is reported to the caller as null.
            return null;
        }
    }

    /**
     * Queries exactly one row with bound parameters and copies it into a new instance
     * of {@code clazz}.
     *
     * @return the populated object, or {@code null} when the query returns no row
     * @throws MultiRowsFoundException if the query returns more than one row
     */
    @Override
    public <T> T queryForObject(String sql, Object[] params, ColumnNameStrategy columnNameStrategy, Class<T> clazz) {
        try {
            return doQueryObject(columnNameStrategy, clazz,
                    (setMethodMap, target) -> query(sql, params, singleRowObjectHandler(setMethodMap, target)));
        } catch (DataNotFoundException ignored) {
            // An empty result is reported to the caller as null.
            return null;
        }
    }

    /**
     * Maps every row of the result with the callback.
     *
     * @return the mapped rows in result set order, never {@code null}
     */
    @Override
    public <T> List<T> queryForList(String sql, ResultSetRowCallback<T> callback) {
        List<T> list = new ArrayList<>();
        query(sql, rs -> {
            while (rs.next()) {
                list.add(callback.callback(rs));
            }
        });
        return list;
    }

    /**
     * Maps every row of the parameterized query result with the callback.
     *
     * @return the mapped rows in result set order, never {@code null}
     */
    @Override
    public <T> List<T> queryForList(String sql, Object[] params, ResultSetRowCallback<T> callback) {
        List<T> list = new ArrayList<>();
        query(sql, params, rs -> {
            while (rs.next()) {
                list.add(callback.callback(rs));
            }
        });
        return list;
    }

    /** Returns the first column of every row converted to {@code clazz}. */
    @Override
    public <T> List<T> queryForList(String sql, Class<T> clazz) {
        return queryForList(sql, firstColumnCallback(clazz));
    }

    /** Returns the first column of every row of the parameterized query converted to {@code clazz}. */
    @Override
    public <T> List<T> queryForList(String sql, Object[] params, Class<T> clazz) {
        return queryForList(sql, params, firstColumnCallback(clazz));
    }

    /** Copies every row into a new instance of {@code clazz}. */
    @Override
    public <T> List<T> queryForObjectList(String sql, ColumnNameStrategy columnNameStrategy, Class<T> clazz) {
        return queryForList(sql, objectRowCallback(columnNameStrategy, clazz));
    }

    /** Copies every row of the parameterized query into a new instance of {@code clazz}. */
    @Override
    public <T> List<T> queryForObjectList(String sql, Object[] params, ColumnNameStrategy columnNameStrategy, Class<T> clazz) {
        return queryForList(sql, params, objectRowCallback(columnNameStrategy, clazz));
    }

    /**
     * Queries exactly one row and returns it as a map.
     *
     * @return the row, or {@code null} when the query returns no row
     * @throws MultiRowsFoundException if the query returns more than one row
     */
    @Override
    public Map<String, Object> queryForMap(String sql) {
        Map<String, Object> map = new HashMap<>();
        try {
            query(sql, singleRowMapHandler(map));
        } catch (DataNotFoundException ignored) {
            // An empty result is reported to the caller as null.
            return null;
        }
        return map;
    }

    /**
     * Queries exactly one row with bound parameters and returns it as a map.
     *
     * @return the row, or {@code null} when the query returns no row
     * @throws MultiRowsFoundException if the query returns more than one row
     */
    @Override
    public Map<String, Object> queryForMap(String sql, Object[] params) {
        Map<String, Object> map = new HashMap<>();
        try {
            query(sql, params, singleRowMapHandler(map));
        } catch (DataNotFoundException ignored) {
            // An empty result is reported to the caller as null.
            return null;
        }
        return map;
    }

    /** Returns every row as a map. */
    @Override
    public List<Map<String, Object>> queryForMapList(String sql) {
        return queryForList(sql, mapRowCallback());
    }

    /** Returns every row of the parameterized query as a map. */
    @Override
    public List<Map<String, Object>> queryForMapList(String sql, Object[] params) {
        return queryForList(sql, params, mapRowCallback());
    }

    /**
     * Streams every row as a new instance of {@code clazz}.
     * The returned stream must be closed by the caller to release the JDBC resources.
     */
    @Override
    public <T> Stream<T> queryForObjectStream(String sql, ColumnNameStrategy columnNameStrategy, Class<T> clazz) {
        return queryForStream(sql, objectRowCallback(columnNameStrategy, clazz));
    }

    /**
     * Streams every row of the parameterized query as a new instance of {@code clazz}.
     * The returned stream must be closed by the caller to release the JDBC resources.
     */
    @Override
    public <T> Stream<T> queryForObjectStream(String sql, Object[] params, ColumnNameStrategy columnNameStrategy, Class<T> clazz) {
        return queryForStream(sql, params, objectRowCallback(columnNameStrategy, clazz));
    }

    /**
     * Streams the first column of every row converted to {@code clazz}.
     * The returned stream must be closed by the caller to release the JDBC resources.
     */
    @Override
    public <T> Stream<T> queryForStream(String sql, Class<T> clazz) {
        return queryForStream(sql, firstColumnCallback(clazz));
    }

    /**
     * Streams the first column of every row of the parameterized query converted to {@code clazz}.
     * The returned stream must be closed by the caller to release the JDBC resources.
     */
    @Override
    public <T> Stream<T> queryForStream(String sql, Object[] params, Class<T> clazz) {
        return queryForStream(sql, params, firstColumnCallback(clazz));
    }

    /**
     * Executes the SQL and returns a stream over the mapped rows.
     *
     * <p>The connection, statement and result set stay open until the returned stream is
     * closed, so callers should use try-with-resources. If the query cannot be started,
     * everything opened so far is released before the exception is propagated.
     *
     * @param sql SQL to execute
     * @param callback row mapper
     * @param <T> element type
     * @return stream of mapped rows
     * @throws DataAccessException if an {@link SQLException} occurs while starting the query
     */
    @Override
    public <T> Stream<T> queryForStream(String sql, ResultSetRowCallback<T> callback) {
        Connection connection = null;
        Statement statement = null;
        ResultSet resultSet = null;
        try {
            connection = getConnection();
            statement = connection.createStatement();
            resultSet = statement.executeQuery(sql);
            return openStream(connection, statement, resultSet, callback);
        } catch (SQLException e) {
            // No stream was handed out, so nobody else can release these resources.
            freeConnection(connection, statement, resultSet);
            throw new DataAccessException(EXECUTE_FAIL_MESSAGE, e);
        } catch (RuntimeException e) {
            freeConnection(connection, statement, resultSet);
            throw e;
        }
    }

    /**
     * Executes the parameterized SQL and returns a stream over the mapped rows.
     *
     * <p>The connection, statement and result set stay open until the returned stream is
     * closed, so callers should use try-with-resources. If the query cannot be started,
     * everything opened so far is released before the exception is propagated.
     *
     * @param sql SQL to execute
     * @param params positional parameters, {@code null} means none
     * @param callback row mapper
     * @param <T> element type
     * @return stream of mapped rows
     * @throws DataAccessException if an {@link SQLException} occurs while starting the query
     */
    @Override
    public <T> Stream<T> queryForStream(String sql, Object[] params, ResultSetRowCallback<T> callback) {
        Connection connection = null;
        PreparedStatement pstmt = null;
        ResultSet resultSet = null;
        try {
            connection = getConnection();
            pstmt = connection.prepareStatement(sql);
            bindParameters(pstmt, params);
            resultSet = pstmt.executeQuery();
            return openStream(connection, pstmt, resultSet, callback);
        } catch (SQLException e) {
            // No stream was handed out, so nobody else can release these resources.
            freeConnection(connection, pstmt, resultSet);
            throw new DataAccessException(EXECUTE_FAIL_MESSAGE, e);
        } catch (RuntimeException e) {
            freeConnection(connection, pstmt, resultSet);
            throw e;
        }
    }

    /**
     * Queries exactly one row and returns it as a {@link Row}.
     *
     * @return the row, or {@code null} when the query returns no row
     * @throws MultiRowsFoundException if the query returns more than one row
     */
    @Override
    public Row queryForRow(String sql) {
        RowImpl row = new RowImpl();
        try {
            query(sql, singleRowMapHandler(row));
        } catch (DataNotFoundException ignored) {
            // An empty result is reported to the caller as null.
            return null;
        }
        return row;
    }

    /**
     * Queries exactly one row with bound parameters and returns it as a {@link Row}.
     *
     * @return the row, or {@code null} when the query returns no row
     * @throws MultiRowsFoundException if the query returns more than one row
     */
    @Override
    public Row queryForRow(String sql, Object[] params) {
        RowImpl row = new RowImpl();
        try {
            query(sql, params, singleRowMapHandler(row));
        } catch (DataNotFoundException ignored) {
            // An empty result is reported to the caller as null.
            return null;
        }
        return row;
    }

    /** Returns every row as a {@link Row}. */
    @Override
    public List<Row> queryForRowList(String sql) {
        return queryForList(sql, rowCallback());
    }

    /** Returns every row of the parameterized query as a {@link Row}. */
    @Override
    public List<Row> queryForRowList(String sql, Object[] params) {
        return queryForList(sql, params, rowCallback());
    }

    /** Creates a SQL builder bound to this template. */
    @Override
    public JdbcSqlBuilder sqlBuilder() {
        return new JdbcSqlBuilder(this);
    }

    /** Delegates to {@link TransactionManager#startTransaction()}. */
    @Override
    public void startTransaction() throws TransactionException {
        transactionManager.startTransaction();
    }

    /** Delegates to {@link TransactionManager#commitTransaction()}. */
    @Override
    public void commitTransaction() throws TransactionException {
        transactionManager.commitTransaction();
    }

    /** Delegates to {@link TransactionManager#rollbackTransaction()}. */
    @Override
    public void rollbackTransaction() throws TransactionException {
        transactionManager.rollbackTransaction();
    }

    /** Returns whether the transaction manager currently reports a transaction. */
    @Override
    public boolean isActive() {
        return transactionManager.getCurrentTransactionOptional().isPresent();
    }

    /**
     * Builds a handler that expects exactly one row, maps it and stores it in the holder.
     *
     * @param holder receives the mapped row
     * @param callback row mapper
     * @param <T> result type
     * @return result set handler
     */
    private <T> ResultSetHandler getQueryOneResultSetHandler(Holder<T> holder, ResultSetRowCallback<T> callback) {
        return rs -> {
            if (!rs.next()) {
                throw new DataNotFoundException(DATA_NOT_FOUND_MESSAGE);
            }
            T obj = callback.callback(rs);
            if (rs.next()) {
                throw new MultiRowsFoundException(MULTI_ROWS_FOUND_MESSAGE);
            }
            holder.setObject(obj);
        };
    }

    /**
     * Creates the target object and lets the consumer populate it.
     *
     * @param columnNameStrategy column name strategy
     * @param clazz target type
     * @param consumer receives the setter map and the new object and fills the object
     * @param <T> target type
     * @return the populated object
     */
    private <T> T doQueryObject(ColumnNameStrategy columnNameStrategy, Class<T> clazz, BiConsumer<Map<String, Method>, T> consumer) {
        T t = resultConverter.createObject(clazz);
        consumer.accept(resultConverter.getSetMethodMap(columnNameStrategy, clazz), t);
        return t;
    }

    /**
     * Builds a handler that expects exactly one row and copies it into {@code object}.
     *
     * @param setMethodMap setter map of the target type
     * @param object object to populate
     * @return result set handler
     */
    private ResultSetHandler singleRowObjectHandler(Map<String, Method> setMethodMap, Object object) {
        return rs -> {
            if (!rs.next()) {
                throw new DataNotFoundException(DATA_NOT_FOUND_MESSAGE);
            }
            resultConverter.appendObjectFieldFromResultSet(rs, rs.getMetaData(), setMethodMap, object);
            if (rs.next()) {
                throw new MultiRowsFoundException(MULTI_ROWS_FOUND_MESSAGE);
            }
        };
    }

    /**
     * Builds a handler that expects exactly one row and copies it into {@code map}.
     *
     * @param map map to populate
     * @return result set handler
     */
    private ResultSetHandler singleRowMapHandler(Map<String, Object> map) {
        return rs -> {
            if (!rs.next()) {
                throw new DataNotFoundException(DATA_NOT_FOUND_MESSAGE);
            }
            resultConverter.appendMapValueFromResultSet(rs, rs.getMetaData(), map);
            if (rs.next()) {
                throw new MultiRowsFoundException(MULTI_ROWS_FOUND_MESSAGE);
            }
        };
    }

    /**
     * Builds a row mapper that returns column 1 of the current row converted to {@code clazz}.
     */
    @SuppressWarnings("unchecked")
    private <T> ResultSetRowCallback<T> firstColumnCallback(Class<T> clazz) {
        return rs -> (T) JdbcUtils.getResultSetValue(rs, 1, clazz);
    }

    /**
     * Builds a row mapper that copies the current row into a new instance of {@code clazz}.
     * The setter map is resolved once, when the mapper is created.
     */
    private <T> ResultSetRowCallback<T> objectRowCallback(ColumnNameStrategy columnNameStrategy, Class<T> clazz) {
        Map<String, Method> setMethodMap = resultConverter.getSetMethodMap(columnNameStrategy, clazz);
        return rs -> {
            T t = resultConverter.createObject(clazz);
            resultConverter.appendObjectFieldFromResultSet(rs, rs.getMetaData(), setMethodMap, t);
            return t;
        };
    }

    /** Builds a row mapper that copies the current row into a new {@link HashMap}. */
    private ResultSetRowCallback<Map<String, Object>> mapRowCallback() {
        return rs -> {
            Map<String, Object> map = new HashMap<>();
            resultConverter.appendMapValueFromResultSet(rs, rs.getMetaData(), map);
            return map;
        };
    }

    /** Builds a row mapper that copies the current row into a new {@link RowImpl}. */
    private ResultSetRowCallback<Row> rowCallback() {
        return rs -> {
            RowImpl row = new RowImpl();
            resultConverter.appendMapValueFromResultSet(rs, rs.getMetaData(), row);
            return row;
        };
    }

    /**
     * Wraps an open result set in a stream that releases the JDBC resources when closed.
     *
     * @param connection connection the query runs on
     * @param stmt statement that produced the result set
     * @param resultSet open result set
     * @param callback row mapper
     * @param <T> element type
     * @return stream over the mapped rows
     * @throws SQLException if the stream cannot be created
     */
    private <T> Stream<T> openStream(Connection connection, Statement stmt, ResultSet resultSet,
                                     ResultSetRowCallback<T> callback) throws SQLException {
        return new ResultSetSpliterator<>(resultSet, callback).stream()
                .onClose(() -> freeConnection(connection, stmt, resultSet));
    }

    /**
     * Binds positional parameters, starting at JDBC index 1.
     *
     * @param pstmt statement to bind on
     * @param params parameters in order, {@code null} means none
     * @throws SQLException if a parameter cannot be bound
     */
    private static void bindParameters(PreparedStatement pstmt, Object[] params) throws SQLException {
        if (params == null) {
            return;
        }
        for (int i = 0; i < params.length; i++) {
            pstmt.setObject(i + 1, params[i]);
        }
    }

    /**
     * Releases the result set, the statement and then the connection.
     *
     * <p>For a {@link WrapCacheDataSource} the connection is handed back through
     * {@code free} only when no transaction is present; for every other data source the
     * connection is closed.
     * NOTE(review): the close branches run even when the connection came from the current
     * transaction - confirm that closing is safe for those data sources.
     *
     * @param connection connection to release, may be {@code null}
     * @param stmt statement to close, may be {@code null}
     * @param resultSet result set to close, may be {@code null}
     */
    private void freeConnection(Connection connection, Statement stmt, ResultSet resultSet) {
        JdbcUtils.closeResultSet(resultSet);
        JdbcUtils.closeStatement(stmt);
        if (connection == null) {
            // Acquiring the connection failed, so there is nothing to hand back.
            return;
        }
        if (dataSource instanceof WrapBasicDataSource) {
            JdbcUtils.closeConnection(connection);
        } else if (dataSource instanceof WrapCacheDataSource) {
            if (!transactionManager.getCurrentTransactionOptional().isPresent()) {
                ((WrapCacheDataSource) dataSource).free(connection);
            }
        } else {
            JdbcUtils.closeConnection(connection);
        }
    }

    /**
     * Releases a connection that has no statement or result set attached.
     *
     * @param connection connection to release, may be {@code null}
     */
    private void freeConnection(Connection connection) {
        freeConnection(connection, null, null);
    }

    /**
     * Releases a statement and its connection.
     *
     * @param connection connection to release, may be {@code null}
     * @param stmt statement to close, may be {@code null}
     */
    private void freeConnection(Connection connection, Statement stmt) {
        freeConnection(connection, stmt, null);
    }

    /**
     * Returns the connection of the current transaction when the transaction manager
     * exposes one, otherwise a new connection from the data source.
     *
     * @return connection to use
     * @throws TransactionException if the data source cannot supply a connection
     */
    private Connection getConnection() {
        Optional<Connection> optional = transactionManager.getCurrentConnectionOptional();
        if (optional.isPresent()) {
            return optional.get();
        }
        try {
            return dataSource.getConnection();
        } catch (SQLException e) {
            throw new TransactionException(e);
        }
    }
}
